"""
measure_wilson.py
~~~~~~~~~~~~~~~~~~

Compute Wilson loops on a two-dimensional periodic lattice for one or more
gauge groups.  For each loop size ``L`` the average value of the loop operator
is measured by sliding an ``L×L`` square contour over every lattice site and
multiplying the link variables along its edges in path order (links traversed
against their direction enter as their Hermitian conjugate).  For non-Abelian
gauge groups the (unnormalized) trace of the resulting product matrix is
taken; for U(1) the link variables are 1×1, so the trace reduces to the
complex number itself.

The results are written to CSV files named ``wilson_<group>.csv`` in the
``results_dir`` specified by the configuration.  Each file has columns
``size`` (loop size), ``real`` and ``imag`` for the average complex value
of the loop operator.
"""

from __future__ import annotations

import os
import yaml
import numpy as np
import pandas as pd
from typing import Dict, Tuple, List


def build_link_index_map(lattice: np.ndarray) -> Dict[Tuple[int, int, int], int]:
    """Return a lookup table from ``(x, y, μ)`` to the position of that link.

    Parameters
    ----------
    lattice : numpy.ndarray
        Sequence of link descriptors (as produced by ``build_lattice``), each
        of the form ``((x, y), μ)``.

    Returns
    -------
    dict
        Keys are ``(x, y, μ)`` converted to plain ``int``; values are the
        integer positions of the corresponding links in ``lattice``.  If a
        descriptor occurs more than once, its last position wins.
    """
    # Nested unpacking enforces exactly two site coordinates per descriptor;
    # int() strips numpy scalar types so lookups with Python ints succeed.
    return {
        (int(sx), int(sy), int(direction)): position
        for position, ((sx, sy), direction) in enumerate(lattice)
    }


def compute_wilson_loop_average(
    U: np.ndarray,
    mapping: Dict[Tuple[int, int, int], int],
    lattice_size: int,
    loop_size: int,
    gauge_group: str,
) -> complex:
    """Compute the average Wilson loop of a given size for a gauge group.

    The loop anchored at site ``(x, y)`` is the ordered product of link
    variables around the square with corners ``(x, y)``, ``(x+L, y)``,
    ``(x+L, y+L)`` and ``(x, y+L)``, walked in the order +x, +y, -x, -y.
    A link walked against its direction ``μ`` enters as its Hermitian
    conjugate (complex conjugate for U(1)), so the contour is closed and the
    result is gauge invariant.  Coordinates wrap periodically modulo
    ``lattice_size``, and the value is averaged over all ``N×N`` anchors.

    Parameters
    ----------
    U : numpy.ndarray
        Link variables for the gauge group.  For U(1) this is a one-dimensional
        complex array of length ``num_links``.  For SU(2) and SU(3) it is an
        array of shape ``(num_links, n, n)`` with ``n`` the dimension of the
        representation (2 or 3).  Matrices are multiplied in path order, so
        they need not be diagonal.
    mapping : dict
        Mapping from ``(x, y, μ)`` to link index (``μ = 0`` for +x, ``μ = 1``
        for +y).
    lattice_size : int
        Number of sites along each dimension of the square lattice.
    loop_size : int
        Size of the square loop (number of links along each side).
    gauge_group : str
        ``'U1'`` (case-insensitive) selects scalar arithmetic; any other value
        (``'SU2'``, ``'SU3'``) selects matrix products and an unnormalized
        trace.

    Returns
    -------
    complex
        The average value of the Wilson loop operator for the given gauge and
        loop size.

    Raises
    ------
    ValueError
        If ``lattice_size`` or ``loop_size`` is smaller than 1.
    KeyError
        If a link required by the contour is missing from ``mapping``.
    """
    if lattice_size < 1:
        raise ValueError(f"lattice_size must be >= 1, got {lattice_size}")
    if loop_size < 1:
        raise ValueError(f"loop_size must be >= 1, got {loop_size}")

    is_u1 = gauge_group.upper() == 'U1'
    N = lattice_size
    L = loop_size

    def link(x: int, y: int, mu: int, backward: bool):
        """Fetch the link leaving site (x, y) along mu, conjugated if walked backward."""
        val = U[mapping[(x % N, y % N, mu)]]
        if not backward:
            return val
        return np.conjugate(val) if is_u1 else val.conj().T

    total: complex = 0.0 + 0.0j
    for x in range(N):
        for y in range(N):
            # A backward step from (a+1, b) to (a, b) uses the link stored at
            # the lower endpoint (a, b).  That is why the -x and -y legs
            # index with "L - 1 - s".
            path = (
                [link(x + s, y, 0, False) for s in range(L)]                # +x
                + [link(x + L, y + s, 1, False) for s in range(L)]          # +y
                + [link(x + L - 1 - s, y + L, 0, True) for s in range(L)]   # -x
                + [link(x, y + L - 1 - s, 1, True) for s in range(L)]       # -y
            )
            # Each step creates a new object, so U is never mutated in place.
            product = path[0]
            for factor in path[1:]:
                product = product * factor if is_u1 else product @ factor
            # The product is the loop holonomy; take its trace for matrices.
            total += product if is_u1 else np.trace(product)
    return total / (N * N)


def main(config_path: str = 'config.yaml') -> None:
    """Entry point for measuring Wilson loops and writing results.

    Reads the YAML configuration to determine which gauge groups to process
    and which loop sizes to measure.  For each requested gauge group whose
    link file ``U_<group>.npy`` exists in ``data_dir``, a CSV of Wilson loop
    averages (columns ``size``, ``real``, ``imag``) is written to
    ``results_dir/wilson_<group>.csv``.  Relative directories are resolved
    against the directory containing the configuration file.

    Configuration keys (all optional): ``data_dir`` (default ``'data'``),
    ``results_dir`` (default ``'results'``), ``lattice_file`` (default
    ``'lattice.npy'``), ``lattice_size`` (default: inferred from the lattice),
    ``loop_sizes`` (default: none), ``gauge_groups`` (default ``['U1']``).

    Raises
    ------
    FileNotFoundError
        If the configuration file does not exist.
    """
    import warnings

    cfg_file = config_path if os.path.isabs(config_path) else os.path.abspath(config_path)
    if not os.path.exists(cfg_file):
        raise FileNotFoundError(f"Cannot find configuration file: {config_path}")
    with open(cfg_file, encoding='utf-8') as f:
        # An empty YAML document loads as None; treat it as "all defaults".
        cfg = yaml.safe_load(f) or {}
    base_dir = os.path.dirname(cfg_file)

    def resolve_dir(key: str, default: str) -> str:
        """Resolve a configured directory relative to the config file's folder."""
        value = cfg.get(key, default)
        return value if os.path.isabs(value) else os.path.join(base_dir, value)

    data_dir = resolve_dir('data_dir', 'data')
    results_dir = resolve_dir('results_dir', 'results')
    os.makedirs(results_dir, exist_ok=True)

    # Lattice.  NOTE(security): allow_pickle=True executes arbitrary pickled
    # code; only point data_dir at trusted files.
    lattice_path = os.path.join(data_dir, cfg.get('lattice_file', 'lattice.npy'))
    lattice = np.load(lattice_path, allow_pickle=True)
    mapping = build_link_index_map(lattice)
    if 'lattice_size' in cfg:
        lattice_size = int(cfg['lattice_size'])
    elif mapping:
        # The loop code wraps coordinates modulo N, so N must cover every
        # stored site: infer it from the largest coordinate.
        lattice_size = 1 + max(max(x, y) for x, y, _ in mapping)
    else:
        lattice_size = 4

    loop_sizes: List[int] = [int(size) for size in (cfg.get('loop_sizes') or [])]
    groups_cfg = cfg.get('gauge_groups')
    if groups_cfg is None:
        groups_cfg = ['U1']
    gauge_groups: List[str] = [str(g).upper() for g in groups_cfg]

    known_files = {'U1': 'U_U1.npy', 'SU2': 'U_SU2.npy', 'SU3': 'U_SU3.npy'}
    for group in gauge_groups:
        filename = known_files.get(group)
        if filename is None:
            warnings.warn(f"Unknown gauge group {group!r}; skipping", stacklevel=2)
            continue
        path_U = os.path.join(data_dir, filename)
        if not os.path.exists(path_U):
            # Only groups whose link data is present are measured.
            continue
        U = np.load(path_U, allow_pickle=True)
        rows = []
        for L in loop_sizes:
            avg = compute_wilson_loop_average(U, mapping, lattice_size, L, group)
            rows.append({'size': L, 'real': float(np.real(avg)), 'imag': float(np.imag(avg))})
        # Explicit columns keep the header even when no loop sizes are configured.
        df = pd.DataFrame(rows, columns=['size', 'real', 'imag'])
        df.to_csv(os.path.join(results_dir, f'wilson_{group}.csv'), index=False)


if __name__ == '__main__':
    import sys

    # Optional first CLI argument overrides the default config path.
    main(sys.argv[1] if len(sys.argv) > 1 else 'config.yaml')